import argparse
import json
import os
import random  # For sampling distractor reasons
import re
import sys
from typing import Dict, List, Optional

import pandas as pd
import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

def parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        Namespace with ``start``, ``end``, ``output_dir``, ``csv_path`` and
        ``annotation_path``. Only ``--csv_path`` is mandatory.
    """
    cli = argparse.ArgumentParser(
        description="Run persona-based topic classification over a slice of the dataset.",
    )

    # One (flag, add_argument keyword arguments) entry per option, registered in order.
    option_table = (
        ("--start", {"type": int, "default": 0, "help": "Start index (inclusive) of the slice."}),
        ("--end", {"type": int, "default": None, "help": "End index (inclusive) of the slice."}),
        (
            "--output_dir",
            {"type": str, "default": "topic_results", "help": "Directory to write JSON results."},
        ),
        (
            "--csv_path",
            {
                "type": str,
                "required": True,
                "help": "Path to the input CSV with columns video_id,story",
            },
        ),
        (
            "--annotation_path",
            {
                "type": str,
                "default": "action_annotation.json",
                "help": "Path to action_annotation.json for sampling distractor actions",
            },
        ),
    )
    for flag, options in option_table:
        cli.add_argument(flag, **options)

    return cli.parse_args()

# ---------------------------------------------------------------------------
# Zero-shot system prompt for best-action selection
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = " ".join(
    [
        # Task description: one story, one numbered list of candidate actions.
        "You will be given the STORY of a video advertisement and a numbered list",
        "of candidate ACTIONS that a viewer might take after watching it.",
        "Choose EXACTLY ONE action that best fits the story.",
        # Required reply layout: an "Answer:" line followed by a "Reason:" line.
        "Return EXACTLY two lines:\nAnswer: <action>\nReason: <brief justification>.",
    ]
)

# Expected CSV columns: video_id, story, and optionally reasons (legacy fallback
# for the correct actions: a JSON list or a ';'-separated string).

# Line-anchored, case-insensitive patterns for the two-line reply requested
# from the model ("Answer: ..." / "Reason: ..."). Compiled once, not per record.
_ANSWER_RE = re.compile(r"(?i)^answer:\s*(.+)$", re.MULTILINE)
_REASON_RE = re.compile(r"(?i)^reason:\s*(.+)$", re.MULTILINE)
# An answer that starts with a list number, e.g. "3", "3." or "3. <action text>".
_NUMBERED_ANSWER_RE = re.compile(r"^(\d+)\s*[.):]?\s*(.*)$")


def _resolve_correct_actions(rec: Dict, video_id: str, annotation_data: Dict) -> List:
    """Return the non-empty correct actions for one record.

    The annotation mapping takes priority; when it has no (or empty) entry for
    ``video_id`` the legacy ``reasons`` CSV column is used instead, parsed as
    JSON if possible and as a ';'-separated string otherwise.

    Returns:
        A list of truthy actions; empty when nothing usable was found
        (including a NaN / non-iterable cell in the fallback column).
    """
    correct = annotation_data.get(video_id)

    if not correct:
        raw = rec.get('reasons', '')  # fallback legacy column
        if isinstance(raw, str):
            try:
                correct = json.loads(raw)
            except ValueError:
                # Not JSON (json.JSONDecodeError is a ValueError): treat as ';'-separated.
                correct = [part.strip() for part in raw.split(';') if part.strip()]
        else:
            correct = raw

    if isinstance(correct, str):
        correct = [correct]

    try:
        return [action for action in correct if action]
    except TypeError:
        # Non-iterable value (e.g. NaN from an empty CSV cell, None, a bare number):
        # report "no actions" instead of failing the record with an opaque error.
        return []


def _build_candidates(correct_actions: List, all_actions_pool: List, num_distractors: int = 25) -> List:
    """Return the correct actions plus up to ``num_distractors`` random distractors, shuffled.

    Distractors are sampled without replacement from the entries of
    ``all_actions_pool`` that are not among ``correct_actions``.
    """
    distractor_pool = [a for a in all_actions_pool if a not in correct_actions]
    sample_size = min(num_distractors, len(distractor_pool))
    candidates = correct_actions + random.sample(distractor_pool, sample_size)
    random.shuffle(candidates)
    return candidates


def _generate_reply(llm, tok, messages: List[Dict[str, str]]) -> str:
    """Run one chat-formatted generation and return only the newly generated text."""
    input_ids = tok.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        enable_thinking=False,
    ).to(llm.device)

    with torch.no_grad():
        outputs = llm.generate(
            input_ids=input_ids,
            max_new_tokens=120,
            # NOTE(review): temperature is presumably unused when do_sample=False — confirm.
            temperature=0.0,
            do_sample=False,
        )

    # Drop the prompt tokens so only the model's continuation is decoded.
    return tok.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True).strip()


def _parse_response(raw_resp: str, candidate_actions: List) -> tuple:
    """Extract ``(chosen_action, justification)`` from the model reply.

    The action is the text after the first "Answer:" line, or the whole reply
    when no such line exists. The justification is the text after the first
    "Reason:" line, or "" when absent. When the answer is an in-range list
    number (bare, with trailing punctuation, or followed by exactly that
    candidate's text) it is replaced by the corresponding candidate action;
    any other answer is returned unchanged.
    """
    answer_match = _ANSWER_RE.search(raw_resp)
    chosen = answer_match.group(1).strip() if answer_match else raw_resp.strip()

    reason_match = _REASON_RE.search(raw_resp)
    justification = reason_match.group(1).strip() if reason_match else ""

    numbered = _NUMBERED_ANSWER_RE.match(chosen)
    if numbered:
        position = int(numbered.group(1))
        trailing = numbered.group(2).strip()
        if 1 <= position <= len(candidate_actions):
            candidate = candidate_actions[position - 1]
            # Only map when the number stands alone or its text agrees with the
            # numbered candidate, so free-text answers that merely begin with
            # digits are left untouched.
            if not trailing or trailing == str(candidate).strip():
                chosen = candidate

    return chosen, justification


def _save_results(results: List[Dict], output_path: str) -> None:
    """Write ``results`` as JSON to ``output_path`` without ever truncating it in place.

    The data is written to a sibling temporary file first and then moved over
    the target, so an interruption mid-write leaves the previous snapshot intact.
    """
    tmp_path = f"{output_path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(results, f, indent=4)
    os.replace(tmp_path, output_path)


def main():
    """Select the best-fitting action for each story in a slice of the CSV.

    Reads the annotation JSON and the input CSV, loads the chat model, and for
    every record in ``[--start, --end]`` builds a shuffled candidate list
    (correct actions plus random distractors), asks the model to pick one, and
    appends the prediction to ``topic_results_<start>_<end>.json`` in the output
    directory. The results file is rewritten after every processed record.

    Exits with status 1 if the annotation JSON or the CSV cannot be read.
    Per-record failures are printed and the record is skipped; an inference
    failure is recorded with ``predicted_action == "error_api"``.
    """
    args = parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    # -------------------------------------------------------------------
    # Read the inputs first so a bad path fails before the model is loaded
    # -------------------------------------------------------------------
    try:
        with open(args.annotation_path, "r", encoding="utf-8") as f:
            annotation_data = json.load(f)
    except Exception as e:
        print(f"Error reading annotation JSON {args.annotation_path}: {e}")
        sys.exit(1)

    # Flatten all actions into a single pool we can sample distractors from
    all_actions_pool = [act for acts in annotation_data.values() for act in acts]

    try:
        df = pd.read_csv(args.csv_path)
    except Exception as e:
        print(f"Error reading CSV {args.csv_path}: {e}")
        sys.exit(1)

    all_records = df.to_dict(orient='records')

    # Determine slice for this run (both bounds inclusive; end is clamped to the data)
    start_idx = args.start
    end_idx = len(all_records) - 1 if args.end is None else min(args.end, len(all_records) - 1)
    slice_records = all_records[start_idx : end_idx + 1]

    print(f"Processing slice {start_idx}–{end_idx} (n={len(slice_records)})")

    # --------------------------------------------------------------
    # Load the chat model once (kept as module globals, as before)
    # --------------------------------------------------------------
    global model, tokenizer
    model_name = "Qwen/Qwen3-32B"  # change to meta-llama/Llama-3.3-70B-Instruct for LLaMA
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        load_in_4bit=True,
    )

    results = []
    output_path = os.path.join(args.output_dir, f"topic_results_{start_idx}_{end_idx}.json")

    for rec in tqdm(slice_records, desc=f"Persona-Topic Eval {start_idx}-{end_idx}"):
        # Bound before the try so the error handler always reports this record's id.
        video_id = str(rec.get('video_id', '')).strip()
        try:
            correct_actions = _resolve_correct_actions(rec, video_id, annotation_data)
            if not correct_actions:
                print(f"No actions for id {video_id}; skipping")
                continue

            # Candidate list: every correct action plus up to 25 random distractors
            candidate_actions = _build_candidates(correct_actions, all_actions_pool)

            # Collapse every whitespace run (including newlines / form feeds) to one space
            cleaned_text = ' '.join(str(rec.get('story', '')).split())

            actions_block = "\n".join(f"{i+1}. {a}" for i, a in enumerate(candidate_actions))
            user_content = (
                f"Story:\n{cleaned_text}\n\nList of actions:\n{actions_block}\n\n"
                "Return exactly two lines:\nAnswer: <action>\nReason: <brief justification>"
            )
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_content},
            ]

            try:
                raw_resp = _generate_reply(model, tokenizer, messages)
                chosen_action, justification = _parse_response(raw_resp, candidate_actions)
            except Exception as e:
                # Keep the record, flagged with the sentinel, so a failed
                # generation stays visible in the output file.
                print(f"Error during model inference for key {video_id}: {e}")
                chosen_action = "error_api"
                justification = ""

            results.append({
                'video_id': video_id,
                'url': f"https://www.youtube.com/watch?v={video_id}" if video_id else "",
                'story': cleaned_text,
                'predicted_action': chosen_action,
                'explanation': justification,
                'candidate_actions': candidate_actions,
                'correct_actions': correct_actions,
            })

            # Incremental save so an interrupted run keeps everything processed so far
            _save_results(results, output_path)

        except Exception as e:
            print(f"Error processing key {video_id}: {e}")
            continue

    if results:
        print(f"Finished processing. Results saved to {output_path}")
    else:
        print(f"Finished processing. No results were produced; nothing written to {output_path}")

# Run the evaluation only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()




